// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

import "google/protobuf/wrappers.proto";

option go_package = "github.com/google/tsl/tsl/go/stream_executor";

// Specifies the data type used by an operation.
//
// Values keep their legacy kCamelCase names (see the LINT: LEGACY_NAMES
// marker at the top of this file); do not rename them.
enum DataType {
  kFloat = 0;          // 32-bit floating point (F32).
  kDouble = 1;         // 64-bit floating point (F64).
  kHalf = 2;           // 16-bit half precision (F16).
  kInt8 = 3;           // 8-bit signed integer.
  kInt32 = 4;          // 32-bit signed integer.
  kComplexFloat = 5;   // Complex number with F32 real and imaginary parts.
  kComplexDouble = 6;  // Complex number with F64 real and imaginary parts.
  kBF16 = 7;           // bfloat16: 8-bit exponent, 7-bit mantissa.
  kF8E5M2 = 8;         // 8-bit float: 5-bit exponent, 2-bit mantissa.
  kF8E4M3FN = 9;       // 8-bit float: 4-bit exponent, 3-bit mantissa ("FN":
                       // finite-only, no infinities).
  kF8E5M2FNUZ = 10;    // Like kF8E5M2, FNUZ variant (no negative zero).
  kF8E4M3FNUZ = 11;    // Like kF8E4M3FN, FNUZ variant (no negative zero).
  kInt64 = 12;         // 64-bit signed integer.
}

// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel (C)
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  //
  // Note: In cudnn, kBatchDepthYX4 and kBatchDepthYX32 are the same layout
  // (namely, NCHW_VECT_C).  It differentiates between these two by using a
  // different data type (int8x4 vs int8x32).  In StreamExecutor we use
  // different layouts for these, because we don't usually pass an explicit data
  // type to StreamExecutor functions.
  kYXDepthBatch = 0;    // HWCN, per the naming convention above.
  kYXBatchDepth = 1;    // HWNC, per the naming convention above.
  kBatchYXDepth = 2;    // cuDNN's NHWC layout
  kBatchDepthYX = 3;    // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;   // cuDNN's NCHW_VECT_C with 4-elem vectors (e.g. int8x4)
  kBatchDepthYX32 = 5;  // cuDNN's NCHW_VECT_C with 32-elem vects (e.g. int8x32)
}

// Describes how a convolution filter is laid out in the memory.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C (channel)
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  //
  // Values below are not listed in numeric order; only the assigned numbers
  // matter on the wire, so do not renumber them.
  kOutputInputYX = 0;    // cuDNN's NCHW layout
  kOutputYXInput = 1;    // cuDNN's NHWC layout
  kOutputInputYX4 = 2;   // cuDNN's NCHW_VECT_C layout with 4-elem vectors
  kOutputInputYX32 = 5;  // cuDNN's NCHW_VECT_C layout with 32-elem vectors
  // cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
  // When the filter is reordered, so is the bias (if present).
  kOutputInputYX32_CudnnReordered = 6;
  kInputYXOutput = 3;  // CHWN, per the naming convention above.
  kYXInputOutput = 4;  // HWCN, per the naming convention above.
}

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  // Identity: no non-linearity is applied.
  kNone = 0;
  // Logistic sigmoid: f(x) = 1 / (1 + e^-x)
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation; where upper maximum is 6.0.
  kRelu6 = 3;
  // Rectified linear activation; where upper maximum specified by
  // BatchDescriptor::value_max().
  kReluX = 4;
  // Hyperbolic tangent: f(x) = tanh(x)
  kTanh = 5;
  // Like ReluX; but passes all values in the range [-X,X].
  kBandPass = 6;
  // Exponential linear activation: f(x) = x < 0 ? e^x - 1 : x
  kElu = 7;
  // Leaky Rectified linear activation: f(x) = x < 0 ? alpha * x : x
  kLeakyRelu = 8;
  // Gaussian Error linear unit activation:
  //   x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))), where P(X) ~ N(0, 1).
  kGeluExact = 9;
}

// Describes the math definition for the conv op. The popular behavior is
// actually called cross-correlation in math, even though the operation is
// often referred to as convolution. See cuDNN's cudnnConvolutionMode_t.
enum ConvolutionMode {
  // Filter applied without flipping -- the common deep-learning "convolution".
  CROSS_CORRELATION = 0;
  // True mathematical convolution (filter flipped before the sliding product).
  CONVOLUTION = 1;
}

// The kind of convolution operation being described or performed.
enum ConvolutionKind {
  INVALID = 0;  // Unset/unknown; never a valid kind.
  // Forward convolution: input and filter produce the output.
  FORWARD = 1;
  // Gradient with respect to the filter.
  BACKWARD_FILTER = 2;
  // Gradient with respect to the input data.
  BACKWARD_DATA = 3;
  // Forward convolution fused with bias addition and activation.
  FORWARD_BIAS_ACTIVATION = 4;
  // Forward convolution expressed as a fused operation graph.
  // NOTE(review): presumably the cuDNN frontend graph API path -- confirm.
  FORWARD_GRAPH = 5;
}

// Generic tensor representation.
message TensorDescriptorProto {
  // Dimension sizes of the tensor.
  repeated int64 dimensions = 1;
  // Element type of the tensor.
  DataType data_type = 2;
  // Exactly one layout applies, depending on what the tensor holds:
  // conv input/output data uses DataLayout; conv filters use FilterLayout.
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}

// Generic algorithm representation.
message AlgorithmProto {
  // Whether the algorithm is permitted to use reduced-precision tensor-op
  // math.
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may operate 4x4 matrix FMA.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  // Algorithm identifier: a legacy algorithm enum value, or a cuDNN Frontend
  // engine number when is_cudnn_frontend is set (see that field's comment).
  int64 algo_id = 1;
  MathType math_type = 2;
  // Field number 3 was removed; keep it reserved so it is never reused.
  reserved 3;

  // Engine tuning knobs, keyed by knob id.  NOTE(review): map wire/iteration
  // order is undefined -- do not rely on ordering of these entries.
  map<int64, int64> tuning_knobs = 4;
  // Legacy algorithm enums and cuDNN Frontend engine numbers need to coexist in
  // the same proto medium-term, until we can be confident of no longer needing
  // the legacy cuDNN convolution API.  Once the migration is complete, we can
  // stop producing legacy algorithm enums and remove this field.
  bool is_cudnn_frontend = 5;

  // For ROCm only, it's impossible to re-query the required workspace size
  // after running the algorithm search, so we must store the workspace size
  // along with the choice of algorithm.  For consistency and convenience,
  // cuDNN uses this field in the same way, even though it would be possible to
  // re-query the workspace size from cuDNN at each use.
  //
  // Since this message is persisted in files, we need to be able to distinguish
  // 0 workspace size from unknown workspace size in an old message, so this is
  // a message field.
  google.protobuf.UInt64Value workspace_size = 6;
}

// Proto definition of AlgorithmConfig in "dnn.h".
// TODO(ruochengw): After cl/380702564 is submitted, add support for algorithm
// configs with cuDNN Frontend APIs.
message AlgorithmConfigProto {
  // Use oneof to emulate optional semantics in proto2 since older
  // version of proto3 cannot distinguish "unset field" and "default field".
  // The algorithm to use (may require scratch memory).
  oneof optional_algorithm {
    AlgorithmProto algorithm = 1;
  }
  // Fallback algorithm that needs no scratch memory.
  oneof optional_algorithm_no_scratch {
    AlgorithmProto algorithm_no_scratch = 2;
  }
  // Scratch (workspace) size associated with `algorithm`.
  // NOTE(review): units are presumably bytes -- confirm against dnn.h.
  oneof optional_scratch_size {
    int64 scratch_size = 3;
  }
}

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  // Zero-padding for each spatial dimension.
  repeated int64 paddings = 1;
  // Filter stride for each spatial dimension.
  repeated int64 strides = 2;
  // Filter dilation for each spatial dimension.
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // NOTE(review): this corresponds to the cuDNN convolution descriptor's
  // computeType; cudnnConvolutionMode_t (cross-correlation vs. convolution)
  // is the `convolution_mode` field below -- the previous reference to
  // cudnnConvolutionMode_t here appeared to be a copy/paste slip.
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  // Cross-correlation vs. true convolution; see ConvolutionMode above.
  ConvolutionMode convolution_mode = 6;
  // Tensorflow node name, same as in NodeDef, for debugging purposes.
  string name = 7;
}

// Output-type kind for the first batched matmul (BMM1) of a fused
// multi-head attention (MHA) operation.
// NOTE(review): description inferred from the value names -- confirm against
// the fused-MHA implementation.
enum FusedMHAKind {
  BMM1_OUTPUT_UNKNOWN = 0;     // Unset/unknown.
  BMM1_OUTPUT_INPUT_TYPE = 1;  // BMM1 output uses the same type as its input.
  BMM1_OUTPUT_FLOAT = 2;       // BMM1 output is floating point (presumably
                               // F32 -- confirm).
}
